import jax.numpy as jnp
import jax.nn as nn
import numpy as np
import jax


from flax import linen as nn
from flax.linen.initializers import constant, orthogonal
from typing import NamedTuple, Optional,Any,Sequence,Dict
from jax.flatten_util import ravel_pytree
from flax.core.frozen_dict import freeze, unfreeze

class VannillaRNN(nn.Module):
    """Single-step tanh recurrent cell built from two dense projections.

    The recurrent state is carried as a one-element tuple so that callers can
    treat it as a generic pytree.
    """
    # Width of both dense projections and of the state created by
    # `initialize_state`.
    d_model:int

    @nn.compact
    def __call__(self,inputs,last_hidden):
        """Advance the cell by one step.

        Args:
            inputs: current input; projected by a dense layer to `d_model`.
            last_hidden: indexable container whose element 0 is the previous
                hidden state (a one-element tuple when produced by this class).

        Returns:
            `(out, (out,))`: the new activation, and the same activation
            wrapped as the next recurrent state.
        """
        prev_state=last_hidden[0]
        # NOTE: the input projection must be created before the recurrent one;
        # compact modules name submodules by creation order.
        from_input=nn.Dense(self.d_model)(inputs)
        from_state=nn.Dense(self.d_model)(prev_state)
        pre_activation=from_input+from_state
        new_state=jnp.tanh(pre_activation)
        return new_state,(new_state,)

    @staticmethod
    def initialize_state(d_model):
        """Return the all-zeros initial state: a 1-tuple holding a `(d_model,)` vector."""
        zero_state=jnp.zeros((d_model,))
        return (zero_state,)



class UORORNN(nn.Module):
    """Wraps a recurrent cell and augments its parameter gradient with a memory term.

    The wrapped cell is called unchanged in the forward direction.  Its
    backward pass is replaced through `nn.custom_vjp`: the cotangent that
    would flow into the previous hidden state is contracted with a stored
    state-side vector, and the resulting scalars scale a stored parameter-side
    pytree that is added to the ordinary parameter cotangent.  When memory
    gradients are supplied, `__call__` also produces the updated pair for the
    next step using a JVP through the cell, freshly sampled +/-1 signs and
    norm-balancing factors.
    NOTE(review): the class name and the rank-one / random-sign structure
    suggest Unbiased Online Recurrent Optimization (UORO) -- confirm against
    the intended reference.
    """
    # Module class (not instance) of the wrapped cell; instantiated inside
    # `__call__` as `mdl(**mdl_kwargs)`.
    mdl:Any
    # Keyword arguments used to construct `mdl`.
    mdl_kwargs:Dict
    # Added to numerator and denominator of the norm ratios to avoid 0/0.
    eps:float=1e-5

    @nn.compact
    def __call__(self, inputs,last_hidden,memory_grads=None,ret_mem_grad_ax=0):
        """Run one step of the wrapped cell and update the memory gradients.

        Args:
            inputs: input for this time step, forwarded to the wrapped cell.
            last_hidden: previous recurrent state, forwarded to the wrapped cell.
            memory_grads: optional pair `(htilde, htildetheta)` -- a pytree
                shaped like `last_hidden` and a pytree shaped like the wrapped
                cell's variables.  When `None`, no memory update is done and
                the backward pass is the cell's ordinary VJP.
            ret_mem_grad_ax: intended to configure which of the new memory
                grads to return (either -1 or between 0 and T-1).  Currently
                not read anywhere in this method.

        Returns:
            `(cell_output, new_memory_grads)` where `cell_output` is whatever
            the wrapped cell returns and `new_memory_grads` is the updated
            pair, or `None` when `memory_grads` was `None`.

        Requires an rng stream named 'random' when `memory_grads` is given.
        """
        def f(mdl, inputs,last_hidden,memory_grads):
            # memory_grads is accepted only so that it reaches fwd/bwd; the
            # primal computation ignores it.
            return mdl(inputs,last_hidden)

        def fwd(mdl, inputs,last_hidden,memory_grads):
            output,vjp_fn=nn.vjp(f, mdl, inputs,last_hidden,memory_grads)
            # Keep the (pre-update) memory grads as residuals for bwd.
            return output,(vjp_fn,memory_grads)

        def bwd(residuals,y_t):
            vjp_fn,memory_grads=residuals
            # First entry: cotangent for the module variables; the rest are
            # cotangents for (inputs, last_hidden, memory_grads) in that order.
            params_t_1, *inputs_t = vjp_fn(y_t)
            if memory_grads is None:
                # No stored memory term: fall back to the plain VJP instead of
                # failing while unpacking None.
                return (params_t_1, *inputs_t)
            vjp_htminus1=inputs_t[1]
            htilde_tminus1,htildetheta_tminus1=memory_grads
            # Fully contract each leaf of the hidden-state cotangent with the
            # matching leaf of the stored state-side vector.
            h_infl=jax.tree_util.tree_map(lambda ct,ht:jnp.tensordot(ct,ht,axes=ht.ndim),vjp_htminus1,htilde_tminus1)
            # h_infl is a pytree; for each of its leaves build one scaled copy
            # of the stored parameter-side pytree.
            h_infl=jax.tree_util.tree_flatten(h_infl)[0]
            params_addn=[]
            for infl in h_infl:
                # tree_map runs eagerly, so capturing the loop variable is safe.
                params_addn.append(jax.tree_util.tree_map(lambda p:p*infl,htildetheta_tminus1))
            # Leaf-wise sum of the ordinary cotangent and every scaled copy.
            params_t=jax.tree_util.tree_map(lambda *args:jnp.sum(jnp.stack(args),axis=0),params_t_1,*params_addn)
            return (params_t, *inputs_t)

        uoro_grad = nn.custom_vjp(
            f, forward_fn=fwd, backward_fn=bwd)
        mdl_fn=self.mdl(**self.mdl_kwargs)
        if memory_grads is not None:
            htilde_tminus1,htildetheta_tminus1=memory_grads
            #Update to new tilde
            def jvp_fn(last_hidden):
                out,new_hidden=mdl_fn(inputs,last_hidden)
                return new_hidden
            # Push the stored state-side vector through the cell's state
            # transition (tangent of new_hidden w.r.t. last_hidden).
            _,htilde_t_b=jax.jvp(jvp_fn ,(last_hidden,),(htilde_tminus1,))
            #New term: one +/-1 sample per element, one fresh rng per leaf
            htilde_t_a=jax.tree_util.tree_map(lambda leaf:jax.random.choice(self.make_rng('random'),jnp.array([+1.0,-1.0]),leaf.shape),htilde_t_b)

            htildetheta_t_b=htildetheta_tminus1
            # Pull the random signs back onto the cell's variables.
            _,hvjpfn=nn.vjp(lambda mdl: mdl(inputs,last_hidden)[1], mdl_fn,)
            htildetheta_t_a,=hvjpfn(htilde_t_a)

            #Calculate the variance minimization terms
            rho0_a=jnp.linalg.norm(ravel_pytree(htildetheta_t_b)[0])
            rho0_b=jnp.linalg.norm(ravel_pytree(htilde_t_b)[0])
            rho0=jnp.sqrt((rho0_a+self.eps)/(rho0_b+self.eps))
            rho1_a=jnp.linalg.norm(ravel_pytree(htildetheta_t_a)[0])
            rho1_b=jnp.linalg.norm(ravel_pytree(htilde_t_a)[0])
            rho1=jnp.sqrt((rho1_a+self.eps)/(rho1_b+self.eps))
            #Calculate the new memory gradients
            x=jax.tree_util.tree_map(lambda leaf: leaf*rho0,htilde_t_b)
            y=jax.tree_util.tree_map(lambda leaf: leaf*rho1,htilde_t_a)
            htilde_t=jax.tree_util.tree_map(lambda a,b:a+b,x,y)

            irho0=1/rho0
            irho1=1/rho1
            x=jax.tree_util.tree_map(lambda leaf: leaf*irho0,htildetheta_t_b)
            y=jax.tree_util.tree_map(lambda leaf: leaf*irho1,htildetheta_t_a)
            # Normalise both operands to the same container type before adding
            # them leaf-wise.
            x=dict(freeze(x))
            y=dict(freeze(y))
            htildetheta_t=jax.tree_util.tree_map(lambda a,b:a+b,x,y)
            new_memory_grads=(htilde_t,htildetheta_t)
        else: new_memory_grads=memory_grads
        # The custom-VJP call receives the memory grads from *before* this
        # step's update.
        return uoro_grad(mdl_fn, inputs,last_hidden,memory_grads),new_memory_grads


    @staticmethod
    def initialize_memory(mdl,mdl_kwargs,sample_input):
        """Build the initial recurrent state and all-zero memory gradients.

        Args:
            mdl: cell class exposing `initialize_state(**mdl_kwargs)`.
            mdl_kwargs: keyword arguments for both `initialize_state` and the
                cell constructor.
            sample_input: example input used only to initialise the cell's
                variables so their shapes are known.

        Returns:
            `(last_hidden, (h_tilde, h_tilde_theta))` where `h_tilde` is
            zeros shaped like `last_hidden` and `h_tilde_theta` is a `dict`
            of zeros shaped like the cell's initialised variables.
        """
        last_hidden=mdl.initialize_state(**mdl_kwargs)
        # The key is irrelevant: only the shapes of the variables are used.
        params=mdl(**mdl_kwargs).init(jax.random.PRNGKey(0), sample_input,last_hidden)
        h_tilde=jax.tree_util.tree_map(lambda leaf: jnp.zeros_like(leaf),last_hidden)
        h_tilde_theta=dict(jax.tree_util.tree_map(lambda leaf: jnp.zeros_like(leaf),params))
        memory_grads=(h_tilde,h_tilde_theta)
        return last_hidden,memory_grads
    


if __name__=='__main__':
    # Smoke test: build a wrapped vanilla RNN, run a forward rollout and take
    # the gradient of a simple loss w.r.t. the variables.
    last_hidden,memory_grads=UORORNN.initialize_memory(VannillaRNN,{'d_model':32},jnp.zeros((1,32)))
    model=UORORNN(VannillaRNN,{'d_model':32})
    rngs={
        'params':jax.random.PRNGKey(0),
        'random':jax.random.PRNGKey(0)}
    params=model.init(rngs,jnp.zeros((32)),last_hidden,memory_grads)


    #Test forward pass
    #We need to split rng during each step during forward pass, make sure same rng is propagated also during backward pass
    #RNG to use for sampling random signs
    sample_rng=jax.random.PRNGKey(0)


    def applyuoro(params,inputs,last_hidden,memory_grads,rng):
        """Roll the model over the leading axis of `inputs` and return a loss.

        Args:
            params: variables returned by `model.init`.
            inputs: array whose leading axis indexes time steps; must contain
                at least one step.
            last_hidden: initial recurrent state.
            memory_grads: initial memory gradients.
            rng: PRNG key from which one sub-key per step is derived for the
                'random' stream.

        Returns:
            Sum of squared differences between 1 and the last step's output.
        """
        for i in range(inputs.shape[0]):
            # Use one half of the split for this step and carry the other
            # half forward, so every step samples with a different key.
            step_rng, rng = jax.random.split(rng)
            (new_out,new_hidden),memory_grads=model.apply(params,inputs[i],last_hidden,memory_grads,rngs={'random':step_rng})
            last_hidden=new_hidden
            print(new_out)
        return ((1-new_out)**2).sum()

    #Test forward pass
    applyuoro(params,jnp.ones((2,32)),last_hidden,memory_grads,sample_rng)

    #Test backward pass
    x=jax.grad(applyuoro,argnums=0)(params,jnp.ones((1,32)),last_hidden,memory_grads,sample_rng)
    print(x)